-
Notifications
You must be signed in to change notification settings - Fork 13
Update custom_ops.py #315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Update custom_ops.py #315
Conversation
Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai.
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #315 +/- ##
==========================================
- Coverage 96.51% 96.50% -0.01%
==========================================
Files 32 32
Lines 3439 3436 -3
==========================================
- Hits 3319 3316 -3
Misses 120 120 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kmulderdas - is having these functions jitted blocking you jitting a function which calls them in s2ai
? Generally jax.jit
wrapped functions should be able to be arbitrarily nested and there is some suggestion it can help with compilation times when a lower-level function is reused multiple times in a higher level function. I suspect removing the jit
decorators on the utility functions here won't have any major performance implications but it would be good to understand why its causing issues as we widely apply jax.jit
to other functions in s2fft
which are likely to be contained in higher-level jitted functions.
@@ -86,7 +86,7 @@ def wigner_subset_to_s2( | |||
return np.fft.ifft(x, axis=-2, norm="forward") | |||
|
|||
|
|||
@partial(jit, static_argnums=(3, 4)) | |||
# @partial(jit, static_argnums=(3, 4)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# @partial(jit, static_argnums=(3, 4)) |
We generally shouldn't comment out code as we can always recover snippets from git history - this is likely to be what is causing the linting failures.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These jitted functions are called in the lifted convolution layers in s2ai
. For performance and general functionality reasons it is desirable to have the flax model, build from these layers, be traceable and jittable at the top level. These utility functions specifically cause errors when trying to trace/jit at the aforementioned level. It is not clear to me why the other imported jitted functions from s2fft
or the ones natively defined in s2ai
don't break in the same way.
@@ -209,7 +209,7 @@ def so3_to_wigner_subset( | |||
return s2_to_wigner_subset(x, spins, DW, L, sampling) | |||
|
|||
|
|||
@partial(jit, static_argnums=(3, 4, 5)) | |||
# @partial(jit, static_argnums=(3, 4, 5)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# @partial(jit, static_argnums=(3, 4, 5)) |
@@ -338,7 +338,7 @@ def s2_to_wigner_subset( | |||
return x * (2.0 * np.pi) ** 2 | |||
|
|||
|
|||
@partial(jit, static_argnums=(3, 4)) | |||
# @partial(jit, static_argnums=(3, 4)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# @partial(jit, static_argnums=(3, 4)) |
Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai.